from ssm4epi.models.reporting_delays import (
_model,
to_log_probs,
n_iterations,
N_meis,
N_mle,
N_posterior,
key,
percentiles_of_interest,
account_for_nans,
)
import jax
from pyprojroot.here import here
jax.config.update("jax_enable_x64", True)
import pandas as pd
import jax.random as jrn
from jax import numpy as jnp, vmap
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
modified_efficient_importance_sampling as MEIS,
)
from isssm.estimation import mle_pgssm, initial_theta
from isssm.importance_sampling import (
pgssm_importance_sampling,
ess_pct,
mc_integration,
prediction_percentiles,
normalize_weights,
prediction,
)
from isssm.kalman import state_mode
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams["figure.figsize"] = (20, 6)Application 3: Predictions
df = pd.read_csv(here() / "data/processed/RKI_4day_rt.csv")
df["date"] = pd.to_datetime(df["county_date"])
N_prediction = int(1e5)from isssm.typing import PGSSM
def f_pred(x, s, y):
return y[-7:].sum()
def predict_next_week(date_start, date_end, theta, key):
# indices of df that correspond to the dates
idx_start = df[df["date"] == date_start].index[0]
idx_end = df[df["date"] == date_end].index[0]
aux = (idx_end - idx_start + 1 + 7, 4)
dates = pd.to_datetime(df.iloc[idx_start : idx_end + 1]["date"])
y = jnp.asarray(df.iloc[idx_start : idx_end + 1, 1:5], dtype=jnp.float64)
y_nan = jnp.concatenate((y, jnp.full((7, 4), jnp.nan)), axis=0)
missing_inds = jnp.isnan(y_nan)
_, y_miss = account_for_nans(_model(theta, aux), y_nan, missing_inds)
_model_miss = lambda theta, aux: account_for_nans(
_model(theta, aux), y_nan, missing_inds
)[0]
theta0 = initial_theta(y_miss, _model_miss, theta, aux, n_iterations).x
key, subkey = jrn.split(key)
theta_hat = mle_pgssm(
y_miss, _model_miss, theta0, aux, n_iterations, N_mle, subkey
).x
fitted_model = _model_miss(theta_hat, aux)
proposal_la, _ = LA(y_miss, fitted_model, n_iterations)
key, subkey = jrn.split(key)
proposal_meis, _ = MEIS(
y_miss,
fitted_model,
proposal_la.z,
proposal_la.Omega,
n_iterations,
N_meis,
subkey,
)
prediction_model = _model(theta_hat, aux)
old_dist = prediction_model.dist
prediction_model = PGSSM(
*prediction_model[:-2],
lambda s, xi: old_dist(s, xi[..., 1:]),
prediction_model.xi
)
key, subkey = jrn.split(key)
mean, sd, quantiles = prediction(
f_pred,
y_miss,
proposal_meis,
fitted_model,
N_prediction,
subkey,
percentiles_of_interest,
prediction_model=prediction_model,
)
# only return last week
return (dates, y, theta_hat), jnp.hstack([mean[None], sd[None], quantiles[:]])# string to datetime
first_date_end = pd.to_datetime("2021-03-08")
first_date_start = first_date_end - pd.DateOffset(days=5 * 7 - 1)
n_weeks = 10
key = jrn.PRNGKey(5234524234)
# from showcase model
theta_manual = jnp.array(
[
-8.422721513295569,
-7.476399652760843,
-4.225735089851275,
-3.988944142986405,
-0.41375665515244603,
]
)
key, sk_predict = jrn.split(key)
date_end = first_date_end
date_start = first_date_start
theta = theta_manual
results = []
for i in range(n_weeks):
result = predict_next_week(date_start, date_end, theta, sk_predict)
date_start = date_start + pd.DateOffset(days=7)
date_end = date_end + pd.DateOffset(days=7)
(_, _, theta), _ = result
results.append(result)def result_to_pi(result):
(dates, y, theta_hat), (mean, sd, *quantiles) = result
return jnp.array([quantiles[1], mean, quantiles[-2]])pis = jnp.array([result_to_pi(result) for result in results])
idx_start = df[df["date"] == first_date_start].index[0]
idx_end = df[df["date"] == date_end].index[0]
dates = pd.to_datetime(df.iloc[idx_start : idx_end + 1]["date"])
y = jnp.asarray(df.iloc[idx_start : idx_end + 1, 1:5], dtype=jnp.float64)
plt.plot(y.sum(axis=-1).reshape((-1, 7)).sum(axis=1))
plt.plot(jnp.arange(5, 15), pis[:, 1])
plt.plot(jnp.arange(5, 15), pis[:, 0], color="grey", linestyle="--")
plt.plot(jnp.arange(5, 15), pis[:, 2], color="grey", linestyle="--")